package com.itwenke.springbootdemo.shirosimple.common;

import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

public class MultiThreadingTransactionManager {

    /**
     * 事务管理器
     */
    private final PlatformTransactionManager platformTransactionManager;

    /**
     * 超时时间
     */
    private final long timeout;

    /**
     * 时间单位
     */
    private final TimeUnit unit;

    /**
     * 主线程门闩：当所有的子线程准备完成时，通知主线程判断统一”提交“还是”回滚”
     */
    private final CountDownLatch mainStageLatch = new CountDownLatch(1);

    /**
     * 子线程门闩：count 为0时，说明子线程都已准备完成了
     */
    private CountDownLatch childStageLatch = null;

    /**
     * 是否提交事务
     */
    private final AtomicBoolean isSubmit = new AtomicBoolean(true);

    /**
     * 构造方法
     *
     * @param platformTransactionManager 事务管理器
     * @param timeout 超时时间
     * @param unit 时间单位
     */
    public MultiThreadingTransactionManager(PlatformTransactionManager platformTransactionManager, long timeout, TimeUnit unit) {
        this.platformTransactionManager = platformTransactionManager;
        this.timeout = timeout;
        this.unit = unit;
    }

    /**
     * 任务执行器
     *
     * @param tasks 任务列表
     * @param executorService 线程池
     * @return 是否执行成功
     */
    public boolean execute(List<Runnable> tasks, ThreadPoolTaskExecutor executorService) {
        // 排查null空值
        tasks.removeAll(Collections.singleton(null));

        // 属性初始化
        init(tasks.size());

        for (Runnable task : tasks) {
            // 创建线程
            Thread thread = new Thread(() -> {
                // 判断其它线程是否已经执行任务失败，失败就不执行了
                if (!isSubmit.get()) {
                    childStageLatch.countDown();
                }

                // 开启事务
                DefaultTransactionDefinition defaultTransactionDefinition = new DefaultTransactionDefinition();
                TransactionStatus transactionStatus = platformTransactionManager.getTransaction(defaultTransactionDefinition);
                try {
                    // 执行任务
                    task.run();
                } catch (Exception e) {
                    // 任务执行失败，设置回滚
                    isSubmit.set(false);
                }
                // 计数器减一
                childStageLatch.countDown();

                try {
                    // 等待主线程的指示，判断统一”提交“还是”回滚”
                    mainStageLatch.await();
                    if (isSubmit.get()) {
                        // 提交
                        platformTransactionManager.commit(transactionStatus);
                    } else {
                        // 回滚
                        platformTransactionManager.rollback(transactionStatus);
                    }
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            });
            // 线程池执行任务
            executorService.execute(thread);
        }

        try {
            // 主线程等待所有子线程准备完成，避免死锁，设置超时时间
            childStageLatch.await(timeout, unit);
            long count = childStageLatch.getCount();
            // 主线程等待超时，子线程可能发生长时间阻塞，死锁
            if (count > 0) {
                // 设置回滚
                isSubmit.set(false);
            }
            // 主线程通知子线程”提交“还是”回滚”
            mainStageLatch.countDown();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        // 返回执行结果是否成功
        return isSubmit.get();
    }

    /**
     * 属性初始化
     * @param size 任务数量
     */
    private void init(int size) {
        childStageLatch = new CountDownLatch(size);
    }
}
